// SPDX-License-Identifier: LGPL-3.0-or-later
#include <fcntl.h>
#include <gtest/gtest.h>
#include <sys/stat.h>
#include <sys/types.h>

#include <algorithm>
#include <cmath>
#include <fstream>
#include <vector>

#include "DeepPot.h"
#include "neighbor_list.h"
#include "test_utils.h"

// 1e-10 cannot pass; unclear bug or not
#undef EPSILON
#define EPSILON (std::is_same<VALUETYPE, double>::value ? 1e-7 : 1e-1)

template <class VALUETYPE>
class TestInferDeepPotDpaJAX : public ::testing::Test {
 protected:
  std::vector<VALUETYPE> coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
                                  00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
                                  3.51,  2.51, 2.60, 4.27,  3.22, 1.56};
  std::vector<int> atype = {0, 1, 1, 0, 1, 1};
  std::vector<VALUETYPE> box = {13., 0., 0., 0., 13., 0., 0., 0., 13.};
  // Generated by the following Python code:
  // import numpy as np
  // from deepmd.infer import DeepPot
  // coord = np.array([
  //     12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
  //     00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
  //     3.51,  2.51, 2.60, 4.27,  3.22, 1.56
  // ]).reshape(1, -1)
  // atype = np.array([0, 1, 1, 0, 1, 1])
  // box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.]).reshape(1, -1)
  // dp = DeepPot("deeppot_dpa.savedmodel")
  // e, f, v, ae, av = dp.eval(coord, box, atype, atomic=True)
  // np.set_printoptions(precision=16)
  // print(f"{e.ravel()=} {v.ravel()=} {f.ravel()=} {ae.ravel()=}
  // {av.ravel()=}")

  std::vector<VALUETYPE> expected_e = {-94.24098099691867,  -187.8049502787117,
                                       -187.80486052083617, -94.24059525229518,
                                       -187.80366985846246, -187.8042377490619};
  std::vector<VALUETYPE> expected_f = {
      -0.0020150115442053, -0.0133389255924977, -0.0014347177433057,
      -0.0140757358179293, 0.0031373814221557,  0.0098594354314677,
      0.004755683505073,   0.0099471082374397,  -0.0080868184532793,
      -0.0086166721574536, 0.0037803939137322,  -0.0075733131286482,
      0.0037437603038209,  -0.008452527996008,  0.0134837461840424,
      0.0162079757106944,  0.0049265700151781,  -0.0062483322902769};
  std::vector<VALUETYPE> expected_v = {
      0.0133534319524089,  0.0013445914938337,  -0.0029370551651952,
      0.0002611806151294,  0.004662662211533,   -0.0002717443796319,
      -0.0027779798869954, -0.0003277976466339, 0.0018284972283065,
      0.0085710118978246,  0.0003865036653608,  -0.0057964032875089,
      -0.0014358330222619, 0.0002912625128908,  0.001212630641674,
      -0.0050582608957046, -0.0001087907763249, 0.0040068757134429,
      0.0116736349373084,  0.0007055477968445,  -0.0019544933708784,
      0.0032997459258512,  0.0037887116116712,  -0.0043140890650835,
      -0.0034418738401156, -0.0029420616852742, 0.0038219676716965,
      0.0147134944025738,  0.0005214313829998,  -0.0006524136175906,
      0.0003656980996363,  0.0010046161607714,  -0.0017279359476254,
      0.000111127036911,   -0.0017063190420654, 0.0030174567965904,
      0.0104435705455107,  -0.0008704394438241, 0.0012354202650812,
      0.0009397615830053,  0.0029105236407293,  -0.0044188897903449,
      -0.0011461513500477, -0.0045759080125852, 0.0070310883421107,
      0.0089818851995049,  0.0038819466696704,  -0.005443705549253,
      0.0025390283635246,  0.0012121502955869,  -0.0016998728971157,
      -0.0032355117893925, -0.0015590242752438, 0.0021980725909838};
  int natoms;
  double expected_tot_e;
  std::vector<VALUETYPE> expected_tot_v;

  deepmd::DeepPot dp;

  void SetUp() override {
    dp.init("../../tests/infer/deeppot_dpa.savedmodel");

    natoms = expected_e.size();
    EXPECT_EQ(natoms * 3, expected_f.size());
    EXPECT_EQ(natoms * 9, expected_v.size());
    expected_tot_e = 0.;
    expected_tot_v.resize(9);
    std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
    for (int ii = 0; ii < natoms; ++ii) {
      expected_tot_e += expected_e[ii];
    }
    for (int ii = 0; ii < natoms; ++ii) {
      for (int dd = 0; dd < 9; ++dd) {
        expected_tot_v[dd] += expected_v[ii * 9 + dd];
      }
    }
  };

  void TearDown() override {};
};

TYPED_TEST_SUITE(TestInferDeepPotDpaJAX, ValueTypes);

TYPED_TEST(TestInferDeepPotDpaJAX, cpu_build_nlist) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_e = this->expected_e;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  std::vector<VALUETYPE>& expected_v = this->expected_v;
  int& natoms = this->natoms;
  double& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::DeepPot& dp = this->dp;
  double ener;
  std::vector<VALUETYPE> force, virial;
  dp.compute(ener, force, virial, coord, atype, box);

  EXPECT_EQ(force.size(), natoms * 3);
  EXPECT_EQ(virial.size(), 9);

  EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
}

TYPED_TEST(TestInferDeepPotDpaJAX, cpu_build_nlist_atomic) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_e = this->expected_e;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  std::vector<VALUETYPE>& expected_v = this->expected_v;
  int& natoms = this->natoms;
  double& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::DeepPot& dp = this->dp;
  double ener;
  std::vector<VALUETYPE> force, virial, atom_ener, atom_vir;
  dp.compute(ener, force, virial, atom_ener, atom_vir, coord, atype, box);

  EXPECT_EQ(force.size(), natoms * 3);
  EXPECT_EQ(virial.size(), 9);
  EXPECT_EQ(atom_ener.size(), natoms);
  EXPECT_EQ(atom_vir.size(), natoms * 9);

  EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
  for (int ii = 0; ii < natoms; ++ii) {
    EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON);
  }
  for (int ii = 0; ii < natoms * 9; ++ii) {
    EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON);
  }
}

template <class VALUETYPE>
class TestInferDeepPotDpaJAXNopbc : public ::testing::Test {
 protected:
  std::vector<VALUETYPE> coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
                                  00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
                                  3.51,  2.51, 2.60, 4.27,  3.22, 1.56};
  std::vector<int> atype = {0, 1, 1, 0, 1, 1};
  std::vector<VALUETYPE> box = {};
  // Generated by the following Python code:
  // import numpy as np
  // from deepmd.infer import DeepPot
  // coord = np.array([
  //     12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
  //     00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
  //     3.51,  2.51, 2.60, 4.27,  3.22, 1.56
  // ]).reshape(1, -1)
  // atype = np.array([0, 1, 1, 0, 1, 1])
  // box = None
  // dp = DeepPot("deeppot_dpa.savedmodel")
  // e, f, v, ae, av = dp.eval(coord, box, atype, atomic=True)
  // np.set_printoptions(precision=16)
  // print(f"{e.ravel()=} {v.ravel()=} {f.ravel()=} {ae.ravel()=}
  // {av.ravel()=}")

  std::vector<VALUETYPE> expected_e = {
      -94.24457967995595, -187.81100287606412, -187.81417300904738,
      -94.24110552426328, -187.80431436838532, -187.80457222464983};
  std::vector<VALUETYPE> expected_f = {
      0.0051417559595313,  -0.0021539788479118, -0.0038910585639696,
      -0.0051417559595313, 0.0021539788479118,  0.0038910585639696,
      -0.0035470615733886, 0.0003602503965239,  -0.0001895679272905,
      -0.0117361352793328, 0.0034252835112125,  -0.0071824017939095,
      0.0005398894945495,  -0.0084330745423862, 0.013284532676939,
      0.0147433073581718,  0.0046475406346498,  -0.005912562955739};
  std::vector<VALUETYPE> expected_v = {
      1.8756488030620411e-03,  -7.8574476885035112e-04, -1.4194099050199721e-03,
      -7.8574476885031816e-04, 3.2916334911298369e-04,  5.9461766291377029e-04,
      -1.4194099050199305e-03, 5.9461766291377637e-04,  1.0741480362313257e-03,
      1.9292506069911106e-03,  -8.0819957860435269e-04, -1.4599734323175548e-03,
      -8.0819957860438912e-04, 3.3857009373966171e-04,  6.1161049191681572e-04,
      -1.4599734323175878e-03, 6.1161049191681138e-04,  1.1048447595916697e-03,
      7.6085550162737362e-03,  -9.6772620145935649e-04, 8.3466357433496438e-04,
      -6.0870001592646837e-04, 7.8216372819095133e-05,  -6.6454748650881927e-05,
      1.6484218595781155e-04,  -1.4860654308892986e-05, 9.1346879754941710e-06,
      7.3496442591020616e-03,  -4.4351638231197171e-04, 1.0048648094020743e-03,
      -6.7678341212752506e-04, 1.0347359647873023e-03,  -1.6645682233463639e-03,
      1.4376416549492857e-03,  -1.7116654329527997e-03, 2.8516307661164836e-03,
      1.7683913600324828e-03,  -2.6945055858769765e-04, 3.0381043714224000e-04,
      9.9317892494769217e-04,  2.9343081937224687e-03,  -4.4381349285303948e-03,
      -1.6684804850477311e-03, -4.5760063242183471e-03, 6.9407432349694147e-03,
      7.8021639779489171e-03,  3.5246152053432442e-03,  -4.9415349551876243e-03,
      2.1362265660905197e-03,  1.2226850609509659e-03,  -1.7016886981679963e-03,
      -2.7321994901677166e-03, -1.5683141872155984e-03, 2.1960566952029213e-03};
  int natoms;
  double expected_tot_e;
  std::vector<VALUETYPE> expected_tot_v;

  deepmd::DeepPot dp;

  void SetUp() override {
    dp.init("../../tests/infer/deeppot_dpa.savedmodel");

    natoms = expected_e.size();
    EXPECT_EQ(natoms * 3, expected_f.size());
    EXPECT_EQ(natoms * 9, expected_v.size());
    expected_tot_e = 0.;
    expected_tot_v.resize(9);
    std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
    for (int ii = 0; ii < natoms; ++ii) {
      expected_tot_e += expected_e[ii];
    }
    for (int ii = 0; ii < natoms; ++ii) {
      for (int dd = 0; dd < 9; ++dd) {
        expected_tot_v[dd] += expected_v[ii * 9 + dd];
      }
    }
  };

  void TearDown() override {};
};

TYPED_TEST_SUITE(TestInferDeepPotDpaJAXNopbc, ValueTypes);

TYPED_TEST(TestInferDeepPotDpaJAXNopbc, cpu_build_nlist) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_e = this->expected_e;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  std::vector<VALUETYPE>& expected_v = this->expected_v;
  int& natoms = this->natoms;
  double& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::DeepPot& dp = this->dp;
  double ener;
  std::vector<VALUETYPE> force, virial;
  dp.compute(ener, force, virial, coord, atype, box);

  EXPECT_EQ(force.size(), natoms * 3);
  EXPECT_EQ(virial.size(), 9);

  EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
}

TYPED_TEST(TestInferDeepPotDpaJAXNopbc, cpu_build_nlist_atomic) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_e = this->expected_e;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  std::vector<VALUETYPE>& expected_v = this->expected_v;
  int& natoms = this->natoms;
  double& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::DeepPot& dp = this->dp;
  double ener;
  std::vector<VALUETYPE> force, virial, atom_ener, atom_vir;
  dp.compute(ener, force, virial, atom_ener, atom_vir, coord, atype, box);

  EXPECT_EQ(force.size(), natoms * 3);
  EXPECT_EQ(virial.size(), 9);
  EXPECT_EQ(atom_ener.size(), natoms);
  EXPECT_EQ(atom_vir.size(), natoms * 9);

  EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
  for (int ii = 0; ii < natoms; ++ii) {
    EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON);
  }
  for (int ii = 0; ii < natoms * 9; ++ii) {
    EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON);
  }
}

TYPED_TEST(TestInferDeepPotDpaJAXNopbc, cpu_lmp_nlist) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_e = this->expected_e;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  std::vector<VALUETYPE>& expected_v = this->expected_v;
  int& natoms = this->natoms;
  double& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::DeepPot& dp = this->dp;
  double ener;
  std::vector<VALUETYPE> force, virial;

  std::vector<std::vector<int> > nlist_data = {
      {1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, {0, 1, 3, 4, 5},
      {0, 1, 2, 4, 5}, {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}};
  std::vector<int> ilist(natoms), numneigh(natoms);
  std::vector<int*> firstneigh(natoms);
  deepmd::InputNlist inlist(natoms, &ilist[0], &numneigh[0], &firstneigh[0]);
  convert_nlist(inlist, nlist_data);
  dp.compute(ener, force, virial, coord, atype, box, 0, inlist, 0);

  EXPECT_EQ(force.size(), natoms * 3);
  EXPECT_EQ(virial.size(), 9);

  EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
}

TYPED_TEST(TestInferDeepPotDpaJAXNopbc, cpu_lmp_nlist_atomic) {
  using VALUETYPE = TypeParam;
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  std::vector<VALUETYPE>& expected_e = this->expected_e;
  std::vector<VALUETYPE>& expected_f = this->expected_f;
  std::vector<VALUETYPE>& expected_v = this->expected_v;
  int& natoms = this->natoms;
  double& expected_tot_e = this->expected_tot_e;
  std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  deepmd::DeepPot& dp = this->dp;
  double ener;
  std::vector<VALUETYPE> force, virial, atom_ener, atom_vir;

  std::vector<std::vector<int> > nlist_data = {
      {1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, {0, 1, 3, 4, 5},
      {0, 1, 2, 4, 5}, {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}};
  std::vector<int> ilist(natoms), numneigh(natoms);
  std::vector<int*> firstneigh(natoms);
  deepmd::InputNlist inlist(natoms, &ilist[0], &numneigh[0], &firstneigh[0]);
  convert_nlist(inlist, nlist_data);
  dp.compute(ener, force, virial, atom_ener, atom_vir, coord, atype, box, 0,
             inlist, 0);

  EXPECT_EQ(force.size(), natoms * 3);
  EXPECT_EQ(virial.size(), 9);
  EXPECT_EQ(atom_ener.size(), natoms);
  EXPECT_EQ(atom_vir.size(), natoms * 9);

  EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
  }
  for (int ii = 0; ii < natoms; ++ii) {
    EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON);
  }
  for (int ii = 0; ii < natoms * 9; ++ii) {
    EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON);
  }
}
